{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Base Dependencies\n",
    "import os\n",
    "import pickle\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# LinAlg / Stats / Plotting Dependencies\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "pd.set_option(\"display.precision\", 3)\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Scikit-Learn Imports\n",
    "import sklearn\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.utils._testing import ignore_warnings\n",
    "from sklearn.exceptions import ConvergenceWarning\n",
    "\n",
    "# Utils\n",
    "from patch_evaluation_utils import kendalltau_bpq"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How To Use\n",
    "1. Create the \"embeddings_patch_library\" using \"patch_extraction.py\"\n",
    "2. Run this notebook!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CRC-100K (Without SN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ACC 0.7080779944289693\n",
      "ACC 0.6844011142061281\n",
      "ACC 0.873816155988858\n",
      "ACC 0.7615598885793872\n",
      "ACC 0.8002785515320334\n",
      "ACC 0.8194986072423398\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ADI</th>\n",
       "      <th>BACK</th>\n",
       "      <th>DEB</th>\n",
       "      <th>LYM</th>\n",
       "      <th>MUC</th>\n",
       "      <th>NORM</th>\n",
       "      <th>STR</th>\n",
       "      <th>TUM</th>\n",
       "      <th>All</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>beit</th>\n",
       "      <td>0.963</td>\n",
       "      <td>0.648</td>\n",
       "      <td>0.930</td>\n",
       "      <td>0.905</td>\n",
       "      <td>0.842</td>\n",
       "      <td>0.942</td>\n",
       "      <td>0.943</td>\n",
       "      <td>0.960</td>\n",
       "      <td>0.892</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>beit_90</th>\n",
       "      <td>0.960</td>\n",
       "      <td>0.537</td>\n",
       "      <td>0.916</td>\n",
       "      <td>0.927</td>\n",
       "      <td>0.792</td>\n",
       "      <td>0.949</td>\n",
       "      <td>0.939</td>\n",
       "      <td>0.963</td>\n",
       "      <td>0.873</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>beit_imagenet</th>\n",
       "      <td>0.991</td>\n",
       "      <td>0.908</td>\n",
       "      <td>0.985</td>\n",
       "      <td>0.985</td>\n",
       "      <td>0.967</td>\n",
       "      <td>0.944</td>\n",
       "      <td>0.987</td>\n",
       "      <td>0.974</td>\n",
       "      <td>0.968</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>resnet50_trunc</th>\n",
       "      <td>0.988</td>\n",
       "      <td>0.909</td>\n",
       "      <td>0.900</td>\n",
       "      <td>0.870</td>\n",
       "      <td>0.886</td>\n",
       "      <td>0.988</td>\n",
       "      <td>0.963</td>\n",
       "      <td>0.978</td>\n",
       "      <td>0.935</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SimCLR (BRCA)</th>\n",
       "      <td>0.981</td>\n",
       "      <td>0.765</td>\n",
       "      <td>0.955</td>\n",
       "      <td>0.951</td>\n",
       "      <td>0.926</td>\n",
       "      <td>0.976</td>\n",
       "      <td>0.979</td>\n",
       "      <td>0.973</td>\n",
       "      <td>0.938</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>DINO (BRCA)</th>\n",
       "      <td>0.991</td>\n",
       "      <td>0.729</td>\n",
       "      <td>0.961</td>\n",
       "      <td>0.950</td>\n",
       "      <td>0.978</td>\n",
       "      <td>0.957</td>\n",
       "      <td>0.990</td>\n",
       "      <td>0.973</td>\n",
       "      <td>0.941</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  ADI   BACK    DEB    LYM    MUC   NORM    STR    TUM    All\n",
       "beit            0.963  0.648  0.930  0.905  0.842  0.942  0.943  0.960  0.892\n",
       "beit_90         0.960  0.537  0.916  0.927  0.792  0.949  0.939  0.963  0.873\n",
       "beit_imagenet   0.991  0.908  0.985  0.985  0.967  0.944  0.987  0.974  0.968\n",
       "resnet50_trunc  0.988  0.909  0.900  0.870  0.886  0.988  0.963  0.978  0.935\n",
       "SimCLR (BRCA)   0.981  0.765  0.955  0.951  0.926  0.976  0.979  0.973  0.938\n",
       "DINO (BRCA)     0.991  0.729  0.961  0.950  0.978  0.957  0.990  0.973  0.941"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# k-NN evaluation on CRC-100K without stain normalization (SN): fit on the\n",
    "# 'crc100knonorm_train' embeddings, score on the 'crc100k_val' embeddings.\n",
    "crc100k_nonorm_aucs_all = {}\n",
    "models = ['beit',\n",
    "          'beit_90',\n",
    "          'beit_imagenet',\n",
    "          'resnet50_trunc',\n",
    "          'resnet50_tcga_brca_simclr',\n",
    "          'vits_tcga_brca_dino',\n",
    "         ]\n",
    "# Row labels for the results table, in the same order as `models`.\n",
    "model_names = ['beit',\n",
    "               'beit_90',\n",
    "               'beit_imagenet',\n",
    "               'resnet50_trunc',\n",
    "               'SimCLR (BRCA)',\n",
    "               'DINO (BRCA)',\n",
    "              ]\n",
    "\n",
    "for enc in models:\n",
    "    # Skip encoders that already have results (e.g. a duplicated entry in\n",
    "    # `models`) before paying for the pickle loads.\n",
    "    if enc in crc100k_nonorm_aucs_all:\n",
    "        continue\n",
    "\n",
    "    train_fname = os.path.join('./embeddings_patch_library/', 'crc100knonorm_train_%s.pkl' % enc)\n",
    "    with open(train_fname, 'rb') as handle:\n",
    "        asset_dict = pickle.load(handle)\n",
    "        train_embeddings, train_labels = asset_dict['embeddings'], asset_dict['labels']\n",
    "\n",
    "    val_fname = os.path.join('./embeddings_patch_library/', 'crc100k_val_%s.pkl' % enc)\n",
    "    with open(val_fname, 'rb') as handle:\n",
    "        asset_dict = pickle.load(handle)\n",
    "        val_embeddings, val_labels = asset_dict['embeddings'], asset_dict['labels']\n",
    "\n",
    "    # Fold the 'MUS' class into 'STR', then map labels to integer ids with an\n",
    "    # encoder fitted on the training labels.\n",
    "    train_labels[train_labels=='MUS'] = 'STR'\n",
    "    val_labels[val_labels=='MUS'] = 'STR'\n",
    "    le = LabelEncoder().fit(train_labels)\n",
    "    train_labels = le.transform(train_labels)\n",
    "    val_labels = le.transform(val_labels)\n",
    "\n",
    "    # k-NN classifier with scikit-learn's default hyperparameters.\n",
    "    clf = KNeighborsClassifier().fit(train_embeddings, train_labels)\n",
    "    y_score = clf.predict_proba(val_embeddings)\n",
    "    y_pred = clf.predict(val_embeddings)\n",
    "\n",
    "    # One-vs-rest AUC per class. predict_proba columns are ordered like\n",
    "    # clf.classes_, so iterate over that to keep column i paired with its class.\n",
    "    aucs = []\n",
    "    for i, label in enumerate(clf.classes_):\n",
    "        label_class = np.array(val_labels == label, int)\n",
    "        aucs.append(sklearn.metrics.roc_auc_score(label_class, y_score[:, i]))\n",
    "    # Final entry: macro-averaged one-vs-rest AUC over all classes.\n",
    "    aucs.append(sklearn.metrics.roc_auc_score(val_labels, y_score, average='macro', multi_class='ovr'))\n",
    "    crc100k_nonorm_aucs_all[enc] = aucs\n",
    "    print('ACC',np.mean(val_labels == y_pred))\n",
    "\n",
    "# One row per encoder; one column per class plus the macro average ('All').\n",
    "# The class columns are listed in sorted label order, which is the order\n",
    "# LabelEncoder assigns ids in.\n",
    "aucs_df = pd.DataFrame(crc100k_nonorm_aucs_all).T.loc[models]\n",
    "aucs_df.index = model_names\n",
    "aucs_df.columns = ['ADI', 'BACK', 'DEB', 'LYM', 'MUC', 'NORM', 'STR', 'TUM', 'All']\n",
    "crc100kr = aucs_df.copy()\n",
    "crc100kr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['beit',\n",
       " 'beit_90',\n",
       " 'beit_imagenet',\n",
       " 'resnet50_trunc',\n",
       " 'SimCLR (BRCA)',\n",
       " 'DINO (BRCA)']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the row labels used for the results table.\n",
    "model_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `enc` is not defined in this cell -- it is whatever value a\n",
    "# previously executed cell left behind (hidden kernel state). Confirm which\n",
    "# encoder is intended before relying on this cell.\n",
    "# Loads the 'crc100knonorm_train' embeddings and labels for that encoder.\n",
    "train_fname = os.path.join('./embeddings_patch_library/', 'crc100knonorm_train_%s.pkl' % enc)\n",
    "with open(train_fname, 'rb') as handle:\n",
    "    asset_dict = pickle.load(handle)\n",
    "    train_embeddings, train_labels = asset_dict['embeddings'], asset_dict['labels']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CRC-100K (With SN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ACC 0.8881615598885794\n",
      "ACC 0.9338440111420613\n",
      "ACC 0.8771587743732591\n",
      "ACC 0.9071030640668524\n",
      "ACC 0.8991643454038997\n",
      "ACC 0.9363509749303621\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ADI</th>\n",
       "      <th>BACK</th>\n",
       "      <th>DEB</th>\n",
       "      <th>LYM</th>\n",
       "      <th>MUC</th>\n",
       "      <th>NORM</th>\n",
       "      <th>STR</th>\n",
       "      <th>TUM</th>\n",
       "      <th>All</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>beit</th>\n",
       "      <td>0.986</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.991</td>\n",
       "      <td>0.979</td>\n",
       "      <td>0.923</td>\n",
       "      <td>0.972</td>\n",
       "      <td>0.959</td>\n",
       "      <td>0.974</td>\n",
       "      <td>0.973</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>beit_90</th>\n",
       "      <td>0.996</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.999</td>\n",
       "      <td>0.993</td>\n",
       "      <td>0.983</td>\n",
       "      <td>0.957</td>\n",
       "      <td>0.991</td>\n",
       "      <td>0.977</td>\n",
       "      <td>0.987</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>beit_imagenet</th>\n",
       "      <td>0.989</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.997</td>\n",
       "      <td>0.993</td>\n",
       "      <td>0.899</td>\n",
       "      <td>0.967</td>\n",
       "      <td>0.960</td>\n",
       "      <td>0.973</td>\n",
       "      <td>0.972</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>resnet50_trunc</th>\n",
       "      <td>0.983</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.997</td>\n",
       "      <td>0.974</td>\n",
       "      <td>0.963</td>\n",
       "      <td>0.988</td>\n",
       "      <td>0.982</td>\n",
       "      <td>0.978</td>\n",
       "      <td>0.983</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SimCLR (BRCA)</th>\n",
       "      <td>0.988</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.994</td>\n",
       "      <td>0.980</td>\n",
       "      <td>0.969</td>\n",
       "      <td>0.973</td>\n",
       "      <td>0.979</td>\n",
       "      <td>0.969</td>\n",
       "      <td>0.981</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>DINO (BRCA)</th>\n",
       "      <td>0.999</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.999</td>\n",
       "      <td>0.985</td>\n",
       "      <td>0.992</td>\n",
       "      <td>0.960</td>\n",
       "      <td>0.992</td>\n",
       "      <td>0.967</td>\n",
       "      <td>0.987</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  ADI  BACK    DEB    LYM    MUC   NORM    STR    TUM    All\n",
       "beit            0.986   1.0  0.991  0.979  0.923  0.972  0.959  0.974  0.973\n",
       "beit_90         0.996   1.0  0.999  0.993  0.983  0.957  0.991  0.977  0.987\n",
       "beit_imagenet   0.989   1.0  0.997  0.993  0.899  0.967  0.960  0.973  0.972\n",
       "resnet50_trunc  0.983   1.0  0.997  0.974  0.963  0.988  0.982  0.978  0.983\n",
       "SimCLR (BRCA)   0.988   1.0  0.994  0.980  0.969  0.973  0.979  0.969  0.981\n",
       "DINO (BRCA)     0.999   1.0  0.999  0.985  0.992  0.960  0.992  0.967  0.987"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# k-NN evaluation on CRC-100K with stain normalization (SN): fit on the\n",
    "# 'crc100k_train' embeddings, score on the 'crc100k_val' embeddings.\n",
    "crc100k_aucs_all = {}\n",
    "models = ['beit',\n",
    "          'beit_90',\n",
    "          'beit_imagenet',\n",
    "          'resnet50_trunc',\n",
    "          'resnet50_tcga_brca_simclr',\n",
    "          'vits_tcga_brca_dino',\n",
    "         ]\n",
    "# Row labels for the results table, in the same order as `models`.\n",
    "model_names = ['beit',\n",
    "               'beit_90',\n",
    "               'beit_imagenet',\n",
    "               'resnet50_trunc',\n",
    "               'SimCLR (BRCA)',\n",
    "               'DINO (BRCA)',\n",
    "              ]\n",
    "\n",
    "for enc in models:\n",
    "    # Skip encoders that already have results (e.g. a duplicated entry in\n",
    "    # `models`) before paying for the pickle loads.\n",
    "    if enc in crc100k_aucs_all:\n",
    "        continue\n",
    "\n",
    "    train_fname = os.path.join('./embeddings_patch_library/', 'crc100k_train_%s.pkl' % enc)\n",
    "    with open(train_fname, 'rb') as handle:\n",
    "        asset_dict = pickle.load(handle)\n",
    "        train_embeddings, train_labels = asset_dict['embeddings'], asset_dict['labels']\n",
    "\n",
    "    val_fname = os.path.join('./embeddings_patch_library/', 'crc100k_val_%s.pkl' % enc)\n",
    "    with open(val_fname, 'rb') as handle:\n",
    "        asset_dict = pickle.load(handle)\n",
    "        val_embeddings, val_labels = asset_dict['embeddings'], asset_dict['labels']\n",
    "\n",
    "    # Fold the 'MUS' class into 'STR', then map labels to integer ids with an\n",
    "    # encoder fitted on the training labels.\n",
    "    train_labels[train_labels=='MUS'] = 'STR'\n",
    "    val_labels[val_labels=='MUS'] = 'STR'\n",
    "    le = LabelEncoder().fit(train_labels)\n",
    "    train_labels = le.transform(train_labels)\n",
    "    val_labels = le.transform(val_labels)\n",
    "\n",
    "    # k-NN classifier with scikit-learn's default hyperparameters.\n",
    "    clf = KNeighborsClassifier().fit(train_embeddings, train_labels)\n",
    "    y_score = clf.predict_proba(val_embeddings)\n",
    "    y_pred = clf.predict(val_embeddings)\n",
    "\n",
    "    # One-vs-rest AUC per class. predict_proba columns are ordered like\n",
    "    # clf.classes_, so iterate over that to keep column i paired with its class.\n",
    "    aucs = []\n",
    "    for i, label in enumerate(clf.classes_):\n",
    "        label_class = np.array(val_labels == label, int)\n",
    "        aucs.append(sklearn.metrics.roc_auc_score(label_class, y_score[:, i]))\n",
    "    # Final entry: macro-averaged one-vs-rest AUC over all classes.\n",
    "    aucs.append(sklearn.metrics.roc_auc_score(val_labels, y_score, average='macro', multi_class='ovr'))\n",
    "    crc100k_aucs_all[enc] = aucs\n",
    "    print('ACC',np.mean(val_labels == y_pred))\n",
    "\n",
    "# One row per encoder; one column per class plus the macro average ('All').\n",
    "# The class columns are listed in sorted label order, which is the order\n",
    "# LabelEncoder assigns ids in.\n",
    "aucs_df = pd.DataFrame(crc100k_aucs_all).T.loc[models]\n",
    "aucs_df.index = model_names\n",
    "aucs_df.columns = ['ADI', 'BACK', 'DEB', 'LYM', 'MUC', 'NORM', 'STR', 'TUM', 'All']\n",
    "crc100kn = aucs_df.copy()\n",
    "crc100kn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# NOTE(review): absolute, cluster-specific checkpoint directory -- adjust for\n",
    "# your environment before running.\n",
    "CKPT_DIR = '/dssg/home/acct-medftn/medftn/BEPT/Model/mmselfsup/TCGA_Checkpoints'\n",
    "\n",
    "# The checkpoints are only inspected here, so map every tensor to the CPU;\n",
    "# this also lets the cell run on machines without a GPU.\n",
    "weight1 = torch.load(os.path.join(CKPT_DIR, 'beitv2_backbone_imagenet1k.pth'), map_location='cpu')\n",
    "weight2 = torch.load(os.path.join(CKPT_DIR, 'beitv2_vit-base_imagenet.pth'), map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['backbone.cls_token', 'backbone.mask_token', 'backbone.patch_embed.projection.weight', 'backbone.patch_embed.projection.bias', 'backbone.rel_pos_bias.relative_position_bias_table', 'backbone.rel_pos_bias.relative_position_index', 'backbone.layers.0.gamma_1', 'backbone.layers.0.gamma_2', 'backbone.layers.0.ln1.weight', 'backbone.layers.0.ln1.bias', 'backbone.layers.0.attn.q_bias', 'backbone.layers.0.attn.v_bias', 'backbone.layers.0.attn.qkv.weight', 'backbone.layers.0.attn.proj.weight', 'backbone.layers.0.attn.proj.bias', 'backbone.layers.0.ln2.weight', 'backbone.layers.0.ln2.bias', 'backbone.layers.0.ffn.layers.0.0.weight', 'backbone.layers.0.ffn.layers.0.0.bias', 'backbone.layers.0.ffn.layers.1.weight', 'backbone.layers.0.ffn.layers.1.bias', 'backbone.layers.1.gamma_1', 'backbone.layers.1.gamma_2', 'backbone.layers.1.ln1.weight', 'backbone.layers.1.ln1.bias', 'backbone.layers.1.attn.q_bias', 'backbone.layers.1.attn.v_bias', 'backbone.layers.1.attn.qkv.weight', 'backbone.layers.1.attn.proj.weight', 'backbone.layers.1.attn.proj.bias', 'backbone.layers.1.ln2.weight', 'backbone.layers.1.ln2.bias', 'backbone.layers.1.ffn.layers.0.0.weight', 'backbone.layers.1.ffn.layers.0.0.bias', 'backbone.layers.1.ffn.layers.1.weight', 'backbone.layers.1.ffn.layers.1.bias', 'backbone.layers.2.gamma_1', 'backbone.layers.2.gamma_2', 'backbone.layers.2.ln1.weight', 'backbone.layers.2.ln1.bias', 'backbone.layers.2.attn.q_bias', 'backbone.layers.2.attn.v_bias', 'backbone.layers.2.attn.qkv.weight', 'backbone.layers.2.attn.proj.weight', 'backbone.layers.2.attn.proj.bias', 'backbone.layers.2.ln2.weight', 'backbone.layers.2.ln2.bias', 'backbone.layers.2.ffn.layers.0.0.weight', 'backbone.layers.2.ffn.layers.0.0.bias', 'backbone.layers.2.ffn.layers.1.weight', 'backbone.layers.2.ffn.layers.1.bias', 'backbone.layers.3.gamma_1', 'backbone.layers.3.gamma_2', 'backbone.layers.3.ln1.weight', 'backbone.layers.3.ln1.bias', 'backbone.layers.3.attn.q_bias', 
'backbone.layers.3.attn.v_bias', 'backbone.layers.3.attn.qkv.weight', 'backbone.layers.3.attn.proj.weight', 'backbone.layers.3.attn.proj.bias', 'backbone.layers.3.ln2.weight', 'backbone.layers.3.ln2.bias', 'backbone.layers.3.ffn.layers.0.0.weight', 'backbone.layers.3.ffn.layers.0.0.bias', 'backbone.layers.3.ffn.layers.1.weight', 'backbone.layers.3.ffn.layers.1.bias', 'backbone.layers.4.gamma_1', 'backbone.layers.4.gamma_2', 'backbone.layers.4.ln1.weight', 'backbone.layers.4.ln1.bias', 'backbone.layers.4.attn.q_bias', 'backbone.layers.4.attn.v_bias', 'backbone.layers.4.attn.qkv.weight', 'backbone.layers.4.attn.proj.weight', 'backbone.layers.4.attn.proj.bias', 'backbone.layers.4.ln2.weight', 'backbone.layers.4.ln2.bias', 'backbone.layers.4.ffn.layers.0.0.weight', 'backbone.layers.4.ffn.layers.0.0.bias', 'backbone.layers.4.ffn.layers.1.weight', 'backbone.layers.4.ffn.layers.1.bias', 'backbone.layers.5.gamma_1', 'backbone.layers.5.gamma_2', 'backbone.layers.5.ln1.weight', 'backbone.layers.5.ln1.bias', 'backbone.layers.5.attn.q_bias', 'backbone.layers.5.attn.v_bias', 'backbone.layers.5.attn.qkv.weight', 'backbone.layers.5.attn.proj.weight', 'backbone.layers.5.attn.proj.bias', 'backbone.layers.5.ln2.weight', 'backbone.layers.5.ln2.bias', 'backbone.layers.5.ffn.layers.0.0.weight', 'backbone.layers.5.ffn.layers.0.0.bias', 'backbone.layers.5.ffn.layers.1.weight', 'backbone.layers.5.ffn.layers.1.bias', 'backbone.layers.6.gamma_1', 'backbone.layers.6.gamma_2', 'backbone.layers.6.ln1.weight', 'backbone.layers.6.ln1.bias', 'backbone.layers.6.attn.q_bias', 'backbone.layers.6.attn.v_bias', 'backbone.layers.6.attn.qkv.weight', 'backbone.layers.6.attn.proj.weight', 'backbone.layers.6.attn.proj.bias', 'backbone.layers.6.ln2.weight', 'backbone.layers.6.ln2.bias', 'backbone.layers.6.ffn.layers.0.0.weight', 'backbone.layers.6.ffn.layers.0.0.bias', 'backbone.layers.6.ffn.layers.1.weight', 'backbone.layers.6.ffn.layers.1.bias', 'backbone.layers.7.gamma_1', 'backbone.layers.7.gamma_2', 
'backbone.layers.7.ln1.weight', 'backbone.layers.7.ln1.bias', 'backbone.layers.7.attn.q_bias', 'backbone.layers.7.attn.v_bias', 'backbone.layers.7.attn.qkv.weight', 'backbone.layers.7.attn.proj.weight', 'backbone.layers.7.attn.proj.bias', 'backbone.layers.7.ln2.weight', 'backbone.layers.7.ln2.bias', 'backbone.layers.7.ffn.layers.0.0.weight', 'backbone.layers.7.ffn.layers.0.0.bias', 'backbone.layers.7.ffn.layers.1.weight', 'backbone.layers.7.ffn.layers.1.bias', 'backbone.layers.8.gamma_1', 'backbone.layers.8.gamma_2', 'backbone.layers.8.ln1.weight', 'backbone.layers.8.ln1.bias', 'backbone.layers.8.attn.q_bias', 'backbone.layers.8.attn.v_bias', 'backbone.layers.8.attn.qkv.weight', 'backbone.layers.8.attn.proj.weight', 'backbone.layers.8.attn.proj.bias', 'backbone.layers.8.ln2.weight', 'backbone.layers.8.ln2.bias', 'backbone.layers.8.ffn.layers.0.0.weight', 'backbone.layers.8.ffn.layers.0.0.bias', 'backbone.layers.8.ffn.layers.1.weight', 'backbone.layers.8.ffn.layers.1.bias', 'backbone.layers.9.gamma_1', 'backbone.layers.9.gamma_2', 'backbone.layers.9.ln1.weight', 'backbone.layers.9.ln1.bias', 'backbone.layers.9.attn.q_bias', 'backbone.layers.9.attn.v_bias', 'backbone.layers.9.attn.qkv.weight', 'backbone.layers.9.attn.proj.weight', 'backbone.layers.9.attn.proj.bias', 'backbone.layers.9.ln2.weight', 'backbone.layers.9.ln2.bias', 'backbone.layers.9.ffn.layers.0.0.weight', 'backbone.layers.9.ffn.layers.0.0.bias', 'backbone.layers.9.ffn.layers.1.weight', 'backbone.layers.9.ffn.layers.1.bias', 'backbone.layers.10.gamma_1', 'backbone.layers.10.gamma_2', 'backbone.layers.10.ln1.weight', 'backbone.layers.10.ln1.bias', 'backbone.layers.10.attn.q_bias', 'backbone.layers.10.attn.v_bias', 'backbone.layers.10.attn.qkv.weight', 'backbone.layers.10.attn.proj.weight', 'backbone.layers.10.attn.proj.bias', 'backbone.layers.10.ln2.weight', 'backbone.layers.10.ln2.bias', 'backbone.layers.10.ffn.layers.0.0.weight', 'backbone.layers.10.ffn.layers.0.0.bias', 
'backbone.layers.10.ffn.layers.1.weight', 'backbone.layers.10.ffn.layers.1.bias', 'backbone.layers.11.gamma_1', 'backbone.layers.11.gamma_2', 'backbone.layers.11.ln1.weight', 'backbone.layers.11.ln1.bias', 'backbone.layers.11.attn.q_bias', 'backbone.layers.11.attn.v_bias', 'backbone.layers.11.attn.qkv.weight', 'backbone.layers.11.attn.proj.weight', 'backbone.layers.11.attn.proj.bias', 'backbone.layers.11.ln2.weight', 'backbone.layers.11.ln2.bias', 'backbone.layers.11.ffn.layers.0.0.weight', 'backbone.layers.11.ffn.layers.0.0.bias', 'backbone.layers.11.ffn.layers.1.weight', 'backbone.layers.11.ffn.layers.1.bias'])"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List the parameter names stored under the checkpoint's 'state_dict' entry.\n",
    "weight2['state_dict'].keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Linear Classification Protocol (线性分类协议)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the pre-extracted 'beit' CRC-100K train / val embeddings and labels.\n",
    "enc = 'beit'\n",
    "\n",
    "train_fname = os.path.join('./embeddings_patch_library/', 'crc100k_train_%s.pkl' % enc)\n",
    "with open(train_fname, 'rb') as handle:\n",
    "    asset_dict = pickle.load(handle)\n",
    "train_embeddings = asset_dict['embeddings']\n",
    "train_labels = asset_dict['labels']\n",
    "\n",
    "val_fname = os.path.join('./embeddings_patch_library/', 'crc100k_val_%s.pkl' % enc)\n",
    "with open(val_fname, 'rb') as handle:\n",
    "    asset_dict = pickle.load(handle)\n",
    "val_embeddings = asset_dict['embeddings']\n",
    "val_labels = asset_dict['labels']\n",
    "\n",
    "# Map labels to integer ids; the encoder is fitted on the training labels only.\n",
    "le = LabelEncoder()\n",
    "le.fit(train_labels)\n",
    "train_labels, val_labels = le.transform(train_labels), le.transform(val_labels)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100000, 768)"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the shape of the training embedding array.\n",
    "train_embeddings.shape\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class LinearClassifier(nn.Module):\n",
    "    \"\"\"Two stacked linear layers mapping embeddings to class logits.\n",
    "\n",
    "    There is no non-linearity between the layers, so the model as a whole is\n",
    "    still a linear function of its input.\n",
    "\n",
    "    Args:\n",
    "        input_size: dimensionality of each input embedding.\n",
    "        num_classes: number of output classes (length of the logit vector).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, num_classes):\n",
    "        super(LinearClassifier, self).__init__()\n",
    "        self.linear1 = nn.Linear(input_size, 1000)\n",
    "        # Size the output layer from `num_classes` instead of a hard-coded 9.\n",
    "        self.linear2 = nn.Linear(1000, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return raw (unnormalized) class logits for a batch of embeddings.\"\"\"\n",
    "        # No softmax here: nn.CrossEntropyLoss applies log-softmax itself, so\n",
    "        # it must be fed logits rather than probabilities.\n",
    "        x = self.linear1(x)\n",
    "        return self.linear2(x)\n",
    "\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Derive the problem size from the loaded data rather than hard-coding it.\n",
    "input_size = train_embeddings.shape[1]\n",
    "num_classes = len(np.unique(train_labels))\n",
    "linear_classifier = LinearClassifier(input_size, num_classes).to(device)\n",
    "\n",
    "# Whole train / val sets as tensors; the loop below trains full-batch.\n",
    "# dtypes are pinned to what nn.Linear (float32) and CrossEntropyLoss (long) expect.\n",
    "input_data = torch.tensor(train_embeddings, dtype=torch.float32).to(device)\n",
    "labels = torch.tensor(train_labels, dtype=torch.long).to(device)\n",
    "valinput_data = torch.tensor(val_embeddings, dtype=torch.float32).to(device)\n",
    "labels_test = torch.tensor(val_labels, dtype=torch.long).to(device)\n",
    "\n",
    "# Loss function and optimizer.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(linear_classifier.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)\n",
    "\n",
    "# Training loop: one full-batch gradient step per epoch, then validation accuracy.\n",
    "num_epochs = 100\n",
    "for epoch in range(num_epochs):\n",
    "    # eval() is set at the end of every epoch, so switch back explicitly.\n",
    "    linear_classifier.train()\n",
    "\n",
    "    # Forward pass and loss.\n",
    "    outputs = linear_classifier(input_data)\n",
    "    loss = criterion(outputs, labels)\n",
    "\n",
    "    # Keep the full torch.max result (values, indices) in `predicted_train`.\n",
    "    predicted_train = torch.max(outputs.detach(), 1)\n",
    "    accuracy_train = (predicted_train.indices == labels).sum().item() / labels.shape[0]\n",
    "\n",
    "    # Backward pass and parameter update.\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()},acc_train:{accuracy_train}')\n",
    "\n",
    "    # Validation accuracy after this epoch's update.\n",
    "    linear_classifier.eval()\n",
    "    with torch.no_grad():\n",
    "        outputs_test = linear_classifier(valinput_data)\n",
    "        predicted_test = torch.max(outputs_test, 1).indices\n",
    "        accuracy_test = (predicted_test == labels_test).sum().item() / valinput_data.shape[0]\n",
    "        print(accuracy_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Show the torch.max result computed on the training outputs in the last executed epoch.\n",
    "predicted_train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### BreastPathQ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BreastPathQ: fit a linear regression on each encoder's embeddings and record\n",
    "# the validation MSE together with the kendalltau_bpq score.\n",
    "models = ['resnet50_trunc',\n",
    "          'resnet50_tcga_brca_simclr',\n",
    "          'vits_tcga_brca_dino',\n",
    "         ]\n",
    "# Row labels for the results table, in the same order as `models`.\n",
    "model_names = ['ImageNet',\n",
    "               'SimCLR (BRCA)',\n",
    "               'DINO (BRCA)',\n",
    "              ]\n",
    "\n",
    "bpq_mse_all = []\n",
    "for enc in models:\n",
    "    train_fname = os.path.join('./embeddings_patch_library/', 'breastpathq_train_%s.pkl' % enc)\n",
    "    with open(train_fname, 'rb') as handle:\n",
    "        asset_dict = pickle.load(handle)\n",
    "    train_embeddings = asset_dict['embeddings']\n",
    "    train_labels = asset_dict['labels']\n",
    "\n",
    "    val_fname = os.path.join('./embeddings_patch_library/', 'breastpathq_val_%s.pkl' % enc)\n",
    "    with open(val_fname, 'rb') as handle:\n",
    "        asset_dict = pickle.load(handle)\n",
    "    val_embeddings = asset_dict['embeddings']\n",
    "    val_labels = asset_dict['labels']\n",
    "\n",
    "    clf = LinearRegression()\n",
    "    clf.fit(train_embeddings, train_labels)\n",
    "    y_score = clf.predict(val_embeddings)\n",
    "\n",
    "    # One [MSE, Tau] row per encoder.\n",
    "    bpq_mse_all.append([\n",
    "        sklearn.metrics.mean_squared_error(val_labels, y_score),\n",
    "        kendalltau_bpq(val_labels, y_score),\n",
    "    ])\n",
    "\n",
    "mse_df = pd.DataFrame(bpq_mse_all, columns=['MSE', 'Tau'], index=model_names)\n",
    "bpq = mse_df.copy()\n",
    "bpq"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### BCSS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# k-NN evaluation on BCSS: per-class and macro one-vs-rest AUC for each encoder.\n",
    "bcss_aucs_all = {}\n",
    "models = ['resnet50_trunc',\n",
    "          'resnet50_tcga_brca_simclr',\n",
    "          'vits_tcga_brca_dino',\n",
    "         ]\n",
    "# Row labels for the results table, in the same order as `models`.\n",
    "model_names = ['ImageNet',\n",
    "               'SimCLR (BRCA)',\n",
    "               'DINO (BRCA)',\n",
    "              ]\n",
    "\n",
    "for enc in models:\n",
    "    # Skip encoders that already have results (e.g. a duplicated entry in\n",
    "    # `models`) before paying for the pickle loads.\n",
    "    if enc in bcss_aucs_all:\n",
    "        continue\n",
    "\n",
    "    train_fname = './embeddings_patch_library/bcss_train_%s.pkl' % enc\n",
    "    with open(train_fname, 'rb') as handle:\n",
    "        asset_dict = pickle.load(handle)\n",
    "        train_embeddings, train_labels = asset_dict['embeddings'], asset_dict['labels']\n",
    "\n",
    "    val_fname = './embeddings_patch_library/bcss_val_%s.pkl' % enc\n",
    "    with open(val_fname, 'rb') as handle:\n",
    "        asset_dict = pickle.load(handle)\n",
    "        val_embeddings, val_labels = asset_dict['embeddings'], asset_dict['labels']\n",
    "\n",
    "    # k-NN classifier with scikit-learn's default hyperparameters. Only the\n",
    "    # class probabilities are needed for the AUCs, so no separate predict() pass.\n",
    "    clf = KNeighborsClassifier().fit(train_embeddings, train_labels)\n",
    "    y_score = clf.predict_proba(val_embeddings)\n",
    "\n",
    "    # One-vs-rest AUC per class. predict_proba columns are ordered like\n",
    "    # clf.classes_, so iterate over that to keep column i paired with its class.\n",
    "    aucs = []\n",
    "    for i, label in enumerate(clf.classes_):\n",
    "        label_class = np.array(val_labels == label, int)\n",
    "        aucs.append(sklearn.metrics.roc_auc_score(label_class, y_score[:, i]))\n",
    "    # Final entry: macro-averaged one-vs-rest AUC over all classes.\n",
    "    aucs.append(sklearn.metrics.roc_auc_score(val_labels, y_score, average='macro', multi_class='ovr'))\n",
    "    bcss_aucs_all[enc] = aucs\n",
    "\n",
    "# One row per encoder; one column per class plus the macro average ('All').\n",
    "# Class columns are named from the training labels of the last encoder loaded.\n",
    "aucs_df = pd.DataFrame(bcss_aucs_all).T.loc[models]\n",
    "aucs_df.index = model_names\n",
    "aucs_df.columns = list(np.unique(train_labels)) + ['All']\n",
    "bcss = aucs_df.copy()\n",
    "bcss"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mmselfsup_yzc",
   "language": "python",
   "name": "mmselfsup_yzc"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
